{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## FINE-TUNE BERT-BASE-JAPAENESE FOR CLASSIFICATION OF POLICY AREAS\n",
    "## Stefan Müller and Naofumi Fujimura\n",
    "## Campaign Communication and Legislative Leadership (PSRM)\n",
    "## Code compiled successfully on 2 February 2024 using Python 3.11.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "arm64\n",
      "Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020\n",
      "macOS-14.2.1-arm64-arm-64bit\n",
      "uname_result(system='Darwin', node='Stefans-Mac-Studio.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020', machine='arm64')\n"
     ]
    }
   ],
   "source": [
    "## IMPORT MODULES\n",
    "\n",
    "# Note: you may be asked to install additional modules/dependencies\n",
    "\n",
    "# Dealing with Japanese tokenization\n",
    "import fugashi\n",
    "import ipadic\n",
    "\n",
    "## load general libraries\n",
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    "\n",
    "## load relevant functions from transformers library\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, pipeline, Trainer\n",
    "from sklearn.metrics import precision_recall_fscore_support, accuracy_score, recall_score, precision_score, f1_score\n",
    "\n",
    "## load platform to detect os and get information on machine\n",
    "import platform\n",
    "\n",
    "## get information on system\n",
    "import platform\n",
    "print(platform.machine())\n",
    "print(platform.version())\n",
    "print(platform.platform())\n",
    "print(platform.uname())\n",
    "\n",
    "# arm64\n",
    "# Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020\n",
    "# macOS-14.2.1-arm64-arm-64bit\n",
    "# uname_result(system='Darwin', node='Stefans-Mac-Studio.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020', machine='arm64')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## FUNCTIONS\n",
    "\n",
    "# Load or download tokenizer\n",
    "def transform_labels(label):\n",
    "    \"\"\"\n",
    "    Transforms the label format for the HuggingFace Trainer.\n",
    "\n",
    "    Parameters:\n",
    "    - label (dict): A dictionary containing 'policy_area_num' as a key which\n",
    "                    represents the label for a given data example.\n",
    "\n",
    "    Returns:\n",
    "    - dict: A dictionary with a single key 'labels' and the transformed label as value.\n",
    "    \"\"\"\n",
    "    label = label['policy_area_num']\n",
    "    return {'labels': label}\n",
    "\n",
    "# define tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')\n",
    "\n",
    "def tokenize_data(example):\n",
    "    \"\"\"\n",
    "    Tokenizes a given text data example using a predefined tokenizer.\n",
    "\n",
    "    Parameters:\n",
    "    - example (dict): A dictionary containing 'text' as a key which\n",
    "                      represents the text to be tokenized.\n",
    "\n",
    "    Returns:\n",
    "    - dict: A dictionary containing tokenized data: input ids, attention masks, etc.\n",
    "    \"\"\"\n",
    "    return tokenizer(example['text'], padding='max_length', truncation=True, max_length=512)\n",
    "\n",
    "\n",
    "def compute_metrics(pred):\n",
    "    \"\"\"\n",
    "    Computes evaluation metrics including accuracy, precision, recall, and F1 score\n",
    "    for a given set of predictions.\n",
    "\n",
    "    Parameters:\n",
    "    - pred (Transformers.EvalPrediction): An object that contains two attributes:\n",
    "                                          'label_ids' representing true labels and\n",
    "                                          'predictions' representing the model's predictions.\n",
    "\n",
    "    Returns:\n",
    "    - dict: A dictionary containing the computed metrics (accuracy, precision, recall, and F1 score).\n",
    "    \"\"\"\n",
    "\n",
    "    # Extract the true labels from the provided predictions object\n",
    "    labels = pred.label_ids\n",
    "\n",
    "    # Identify the class (index) with the maximum prediction score for each sample\n",
    "    preds = pred.predictions.argmax(-1)\n",
    "\n",
    "    # Calculate macro-averaged precision, recall, and F1 score.\n",
    "    # The \"macro\" average computes the metric independently for each class and then takes the average.\n",
    "    # Setting zero_division=1 ensures that precision is set to 1.0 in the case where a class has no predicted samples.\n",
    "    Precision, Recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro', zero_division=1)\n",
    "\n",
    "    # Calculate the overall accuracy of the predictions\n",
    "    acc = accuracy_score(labels, preds)\n",
    "\n",
    "    # Return a dictionary containing the calculated metrics\n",
    "    return {\n",
    "        'accuracy': acc,\n",
    "        'f1': f1,\n",
    "        'precision': Precision,\n",
    "        'recall': Recall\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__`Tokenization with the Functions`__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 10645.44it/s]\n",
      "Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 404.54it/s]\n",
      "Generating train split: 2000 examples [00:00, 173777.93 examples/s]\n",
      "Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 6364.65it/s]\n",
      "Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 592.75it/s]\n",
      "Generating train split: 500 examples [00:00, 93614.50 examples/s]\n",
      "Map: 100%|██████████| 2000/2000 [00:00<00:00, 7742.72 examples/s]\n",
      "Map: 100%|██████████| 2000/2000 [00:00<00:00, 34739.18 examples/s]\n",
      "Map: 100%|██████████| 500/500 [00:00<00:00, 9740.88 examples/s]\n",
      "Map: 100%|██████████| 500/500 [00:00<00:00, 28638.66 examples/s]\n"
     ]
    }
   ],
   "source": [
    "# load training dataset\n",
    "dataset_train = load_dataset(\"csv\", data_files = \"data_sentences_train.csv\", encoding='utf-8')\n",
    "dataset_train = dataset_train[\"train\"]\n",
    "\n",
    "# load evaluation dataset during training\n",
    "dataset_eval = load_dataset(\"csv\", data_files = \"data_sentences_eval.csv\", encoding='utf-8')\n",
    "\n",
    "# get datasets into correct format\n",
    "dataset_train = dataset_train.map(tokenize_data, batched=True)\n",
    "dataset_train = dataset_train.map(transform_labels, remove_columns = [\"text\", \"policy_area\", \"policy_area_num\"])\n",
    "dataset_train.set_format(type='torch', columns=['input_ids', 'ntoken', 'token_type_ids', 'attention_mask', 'labels'])\n",
    "\n",
    "dataset_eval = dataset_eval.map(tokenize_data, batched=True)\n",
    "dataset_eval = dataset_eval.map(transform_labels, remove_columns = [\"text\", \"policy_area\", \"policy_area_num\"])\n",
    "dataset_eval.set_format(type='torch', columns=['input_ids', 'ntoken', 'token_type_ids', 'attention_mask', 'labels'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## PREPARE AND FINE-TUNE MODEL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Training arguments: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments\n",
    "- https://huggingface.co/jenspt/byt5_ft_all_clean_data/blob/main/README.md?code=true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['ntoken', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 2000\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# make sure training dataset consists of 2000 statements\n",
    "dataset_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['ntoken', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
       "        num_rows: 500\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# make sure evaluation dataset consists of 500 statements\n",
    "dataset_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# load transformer model\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"cl-tohoku/bert-base-japanese\", num_labels=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 50/315 [46:08<1:41:56, 23.08s/it]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 1.8615, 'learning_rate': 4.491525423728814e-05, 'epoch': 0.79}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                  \n",
      " 16%|█▌        | 50/315 [46:56<1:41:56, 23.08s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_train_loss': 1.304336667060852, 'eval_train_accuracy': 0.648, 'eval_train_f1': 0.45120958073833517, 'eval_train_precision': 0.7415381906116029, 'eval_train_recall': 0.4688793226996551, 'eval_train_runtime': 47.483, 'eval_train_samples_per_second': 10.53, 'eval_train_steps_per_second': 0.337, 'epoch': 0.79}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 32%|███▏      | 100/315 [1:04:34<1:53:42, 31.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.8983, 'learning_rate': 3.644067796610169e-05, 'epoch': 1.59}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                     \n",
      " 32%|███▏      | 100/315 [1:06:29<1:53:42, 31.73s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_train_loss': 0.7586835622787476, 'eval_train_accuracy': 0.774, 'eval_train_f1': 0.7132412818467397, 'eval_train_precision': 0.7685024835511709, 'eval_train_recall': 0.6845087509050104, 'eval_train_runtime': 115.1321, 'eval_train_samples_per_second': 4.343, 'eval_train_steps_per_second': 0.139, 'epoch': 1.59}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 48%|████▊     | 150/315 [1:33:54<1:33:22, 33.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.4974, 'learning_rate': 2.7966101694915255e-05, 'epoch': 2.38}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                     \n",
      " 48%|████▊     | 150/315 [1:35:50<1:33:22, 33.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_train_loss': 0.7047998309135437, 'eval_train_accuracy': 0.778, 'eval_train_f1': 0.7232635256589767, 'eval_train_precision': 0.7565823096448409, 'eval_train_recall': 0.7038271973893777, 'eval_train_runtime': 116.2432, 'eval_train_samples_per_second': 4.301, 'eval_train_steps_per_second': 0.138, 'epoch': 2.38}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 63%|██████▎   | 200/315 [2:07:54<1:19:57, 41.72s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.2926, 'learning_rate': 1.9491525423728814e-05, 'epoch': 3.17}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                     \n",
      " 63%|██████▎   | 200/315 [2:09:37<1:19:57, 41.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_train_loss': 0.7205259799957275, 'eval_train_accuracy': 0.788, 'eval_train_f1': 0.7295922615639051, 'eval_train_precision': 0.74915738454167, 'eval_train_recall': 0.7239879369115302, 'eval_train_runtime': 102.9221, 'eval_train_samples_per_second': 4.858, 'eval_train_steps_per_second': 0.155, 'epoch': 3.17}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 79%|███████▉  | 250/315 [2:56:21<2:25:03, 133.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.1731, 'learning_rate': 1.1016949152542374e-05, 'epoch': 3.97}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                      \n",
      " 79%|███████▉  | 250/315 [2:57:10<2:25:03, 133.89s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_train_loss': 0.7313029170036316, 'eval_train_accuracy': 0.79, 'eval_train_f1': 0.7378607058009155, 'eval_train_precision': 0.7512031277020959, 'eval_train_recall': 0.73333905404118, 'eval_train_runtime': 49.6467, 'eval_train_samples_per_second': 10.071, 'eval_train_steps_per_second': 0.322, 'epoch': 3.97}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 95%|█████████▌| 300/315 [3:31:45<13:07, 52.50s/it]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 0.1033, 'learning_rate': 2.5423728813559323e-06, 'epoch': 4.76}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                   \n",
      " 95%|█████████▌| 300/315 [3:36:02<13:07, 52.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_train_loss': 0.7411648035049438, 'eval_train_accuracy': 0.784, 'eval_train_f1': 0.7347075487478395, 'eval_train_precision': 0.7507532844733983, 'eval_train_recall': 0.7272302989174793, 'eval_train_runtime': 256.8317, 'eval_train_samples_per_second': 1.947, 'eval_train_steps_per_second': 0.062, 'epoch': 4.76}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 315/315 [3:39:59<00:00, 41.90s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_runtime': 13199.9976, 'train_samples_per_second': 0.758, 'train_steps_per_second': 0.024, 'train_loss': 0.6120190942098224, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=315, training_loss=0.6120190942098224, metrics={'train_runtime': 13199.9976, 'train_samples_per_second': 0.758, 'train_steps_per_second': 0.024, 'train_loss': 0.6120190942098224, 'epoch': 5.0})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# training arguments\n",
    "training_args = TrainingArguments(\n",
    "    # output_dir='./model',\n",
    "    do_train=True,\n",
    "    do_eval=True,\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=20,\n",
    "    weight_decay=0.01,\n",
    "    logging_strategy=\"steps\",\n",
    "    logging_dir='./logs',\n",
    "    logging_steps=50,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    load_best_model_at_end=True,\n",
    "    use_cpu=True\n",
    ")\n",
    "\n",
    "# fine-tune the model (takes several hours)\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=dataset_train,\n",
    "    eval_dataset=dataset_eval,\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save model to drive\n",
    "trainer.save_model(\"bert_manifestos\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BertForSequenceClassification(\n",
      "  (bert): BertModel(\n",
      "    (embeddings): BertEmbeddings(\n",
      "      (word_embeddings): Embedding(32000, 768, padding_idx=0)\n",
      "      (position_embeddings): Embedding(512, 768)\n",
      "      (token_type_embeddings): Embedding(2, 768)\n",
      "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "    (encoder): BertEncoder(\n",
      "      (layer): ModuleList(\n",
      "        (0-11): 12 x BertLayer(\n",
      "          (attention): BertAttention(\n",
      "            (self): BertSelfAttention(\n",
      "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "            (output): BertSelfOutput(\n",
      "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (intermediate): BertIntermediate(\n",
      "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            (intermediate_act_fn): GELUActivation()\n",
      "          )\n",
      "          (output): BertOutput(\n",
      "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (pooler): BertPooler(\n",
      "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (activation): Tanh()\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (classifier): Linear(in_features=768, out_features=12, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Load fine-tuned model\n",
    "# model = AutoModelForSequenceClassification.from_pretrained(\"bert_manifestos\", num_labels = 12)\n",
    "\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     policy_area_num                          policy_area  \\\n",
      "0                 10                       No policy area   \n",
      "1                  7          Health, Labour, and Welfare   \n",
      "2                 10                       No policy area   \n",
      "3                  1                Committees on Cabinet   \n",
      "4                 10                       No policy area   \n",
      "..               ...                                  ...   \n",
      "480               10                       No policy area   \n",
      "481               10                       No policy area   \n",
      "482               10                       No policy area   \n",
      "483                8  Internal Affairs and Communications   \n",
      "484               10                       No policy area   \n",
      "\n",
      "                               text  ntoken  \n",
      "0                               三〇代       3  \n",
      "1             医療は勤務医の処遇を改善し、救急医療や病院      14  \n",
      "2                ％平成26年9月自由民主党幹事長代理      12  \n",
      "3                      老朽化対策、さらなる防災       6  \n",
      "4    世界に誇れる、技術と文化を生かした浜松のまちづくりを進めます      19  \n",
      "..                              ...     ...  \n",
      "480                           （２）伝統       4  \n",
      "481                            「安心」       3  \n",
      "482                         庶民派ダントツ       3  \n",
      "483       ◆道州制の導入を視野に入れた特色ある地方自治の推進      16  \n",
      "484                     §http://www       6  \n",
      "\n",
      "[485 rows x 4 columns]\n",
      "485\n"
     ]
    }
   ],
   "source": [
    "# load the CSV file with held-out test data\n",
    "df = pd.read_csv('data_sentences_test.csv')\n",
    "\n",
    "# Ensure that the column containing the text is of type str\n",
    "df['text'] = df['text'].astype(str)\n",
    "\n",
    "# check number of observations (should be 485)\n",
    "num_rows = df.shape[0]\n",
    "print(num_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# detect OS\n",
    "current_os = platform.system()\n",
    "# active device based on OS\n",
    "if current_os == \"Darwin\":\n",
    "    # specify device as mps\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    # check if gpu is available, if yes use cuda, if not stick to cpu\n",
    "    device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "    # must be 'cuda:0', not just 'cuda' due to a bug in transformers library, see: https://github.com/deepset-ai/haystack/issues/3160\n",
    "\n",
    "# check which device is used (GPU or CPU)\n",
    "# if device = mps, then GPU is used on an Apple Silicon Mac or MacBook\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a pipeline for text processing\n",
    "classifier = pipeline(\"text-classification\",\n",
    "                      model = model,\n",
    "                      max_length = 512,\n",
    "                      batch_size = 64,\n",
    "                      device = device,\n",
    "                      padding = 'max_length',\n",
    "                      truncation = True,\n",
    "                      tokenizer = tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# apply classifier to held-out test data\n",
    "text_data = df['text'].tolist()\n",
    "\n",
    "results = pd.DataFrame(classifier(text_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the 'LABEL_' prefix from the 'label' column\n",
    "results['label'] = results['label'].str.replace('LABEL_', '')\n",
    "\n",
    "df_classified = df\n",
    "\n",
    "# Add the results back to the original DataFrame\n",
    "df_classified['label'] = results['label']\n",
    "df_classified['score'] = results['score']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "    label     score\n",
      "0      10  0.994375\n",
      "1       7  0.972203\n",
      "2      10  0.991938\n",
      "3       1  0.655717\n",
      "4      10  0.964203\n",
      "..    ...       ...\n",
      "480    10  0.592376\n",
      "481    10  0.992377\n",
      "482    10  0.966443\n",
      "483     8  0.834104\n",
      "484    10  0.995495\n",
      "\n",
      "[485 rows x 2 columns]\n",
      "     policy_area_num                          policy_area  \\\n",
      "0                 10                       No policy area   \n",
      "1                  7          Health, Labour, and Welfare   \n",
      "2                 10                       No policy area   \n",
      "3                  1                Committees on Cabinet   \n",
      "4                 10                       No policy area   \n",
      "..               ...                                  ...   \n",
      "480               10                       No policy area   \n",
      "481               10                       No policy area   \n",
      "482               10                       No policy area   \n",
      "483                8  Internal Affairs and Communications   \n",
      "484               10                       No policy area   \n",
      "\n",
      "                               text  ntoken label     score  \n",
      "0                               三〇代       3    10  0.994375  \n",
      "1             医療は勤務医の処遇を改善し、救急医療や病院      14     7  0.972203  \n",
      "2                ％平成26年9月自由民主党幹事長代理      12    10  0.991938  \n",
      "3                      老朽化対策、さらなる防災       6     1  0.655717  \n",
      "4    世界に誇れる、技術と文化を生かした浜松のまちづくりを進めます      19    10  0.964203  \n",
      "..                              ...     ...   ...       ...  \n",
      "480                           （２）伝統       4    10  0.592376  \n",
      "481                            「安心」       3    10  0.992377  \n",
      "482                         庶民派ダントツ       3    10  0.966443  \n",
      "483       ◆道州制の導入を視野に入れた特色ある地方自治の推進      16     8  0.834104  \n",
      "484                     §http://www       6    10  0.995495  \n",
      "\n",
      "[485 rows x 6 columns]\n"
     ]
    }
   ],
   "source": [
    "print(df_classified)\n",
    "\n",
    "# save the DataFrame with results back to a new CSV file\n",
    "df_classified.to_csv('data_test_predicted_bert.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
